import numpy as np
import pickle
import matplotlib.pyplot as plt
from random import random, randint


def calc_corrs(arr, nnn=False):
    """
    Sum the local correlator corr_at over every site of the square matrix arr.
    Sites are visited row by row; with nnn=True the diagonal (next nearest neighbour)
    correlator is summed instead of the nearest neighbour one.
    :param arr: square matrix of site values
    :param nnn: use next nearest neighbours instead of nearest neighbours
    :return: lattice-summed correlator (0 for an empty matrix)
    """
    side = arr.shape[0]
    return sum(corr_at(arr, row, col, nnn=nnn)
               for row in range(side)
               for col in range(side))


def corr_at(mat, i, j, nnn=False):
    """
    Local spin correlator at site (i, j) of a periodic square lattice.
    Returns the site value times the mean of its four nearest neighbours, or, with nnn=True,
    of its four diagonal (next nearest) neighbours. All indices wrap modulo mat.shape[0],
    so negative or out-of-range i, j are accepted.
    :param mat: square matrix of site values
    :param i: row index (wrapped)
    :param j: column index (wrapped)
    :param nnn: use diagonal neighbours instead of nearest neighbours
    :return: mat[i, j] * (sum of the four neighbours) / 4.
    """
    size = mat.shape[0]
    if nnn:
        offsets = ((1, 1), (-1, -1), (-1, 1), (1, -1))
    else:
        offsets = ((1, 0), (0, 1), (-1, 0), (0, -1))
    first, second, third, fourth = [mat[(i + di) % size, (j + dj) % size]
                                    for di, dj in offsets]
    centre = mat[i % size, j % size]
    return centre * (first + second + third + fourth) / 4.


def rand_dens(dens):
    """
    Draw a single random site value.
    With probability dens the site holds a spin, +1 or -1 with equal probability;
    otherwise it is empty and 0 is returned. The second random draw only happens
    for occupied sites.
    """
    occupied = random() < dens
    if not occupied:
        return 0
    return 1 if random() < 0.5 else -1


def gen_infinite_temperature(sl, free_holes, doub_hole, proj):
    """
    This function generates a fake hubbard system at infinite spin temperature with the given free
    hole and doublon-hole pair densities. Every site starts with an independent, equally likely
    spin (+1 or -1). Free holes (one site each) and doublon-hole pairs (two adjacent sites each,
    with periodic boundaries) are then punched in at random occupied sites.
    :param sl: side length of return matrix
    :param free_holes: density of free holes
    :param doub_hole: density of doublon-hole pairs (0.1 means 10% doublon-hole pairs or 20% holes!)
    :param proj: unused; accepted so this function matches the fun(sl, free_holes, doub_hole, proj)
        signature that gen_fake calls
    :return: sl x sl int matrix of occupations. 1 and -1 are spin states, 0 is not occupied
    :raises ValueError: if the requested holes do not fit on the sl x sl lattice
    :raises RuntimeError: if doublon-hole pairs remain but no two adjacent occupied sites are left
    """
    n_sites = sl ** 2
    free_holes_remaining = int(free_holes * n_sites)
    doub_hole_remaining = int(doub_hole * n_sites)
    # Without this check an over-full request would make the placement loop below spin forever.
    if free_holes_remaining + 2 * doub_hole_remaining > n_sites:
        raise ValueError('requested %d free holes and %d doublon-hole pairs do not fit on %d sites'
                         % (free_holes_remaining, doub_hole_remaining, n_sites))

    # Infinite spin temperature: independent, equally likely up/down spins on every site.
    mat = np.array([2 * (random() < 0.5) - 1 for _ in range(n_sites)], dtype=int).reshape(sl, sl)

    # (di, dj) for the partner site of a doublon-hole pair, indexed by randint(0, 3)
    directions = ((1, 0), (-1, 0), (0, 1), (0, -1))

    while free_holes_remaining > 0 or doub_hole_remaining > 0:
        (i, j) = (randint(0, sl - 1), randint(0, sl - 1))
        if mat[i, j] == 0:
            continue  # already a hole; pick another site

        # choose the defect type in proportion to how many of each are still to be placed
        if random() < float(free_holes_remaining) \
                / (free_holes_remaining + doub_hole_remaining):
            # placing free hole
            mat[i, j] = 0
            free_holes_remaining -= 1
        else:
            # placing hole + doublon on (i, j) and one randomly chosen neighbour
            di, dj = directions[randint(0, 3)]
            ni, nj = (i + di) % sl, (j + dj) % sl
            if mat[ni, nj] != 0:
                mat[i, j] = 0
                mat[ni, nj] = 0
                doub_hole_remaining -= 1
            else:
                # Fail fast instead of looping forever once no adjacent occupied pair exists.
                occupied = mat != 0
                if not ((occupied & np.roll(occupied, 1, axis=0)).any() or
                        (occupied & np.roll(occupied, 1, axis=1)).any()):
                    raise RuntimeError('no two adjacent occupied sites left for the remaining '
                                       'doublon-hole pairs')

    return mat


def gen_fake(
        sl,
        free_holes,
        doub_hole,
        sz_nn,
        sz_nnn,
        reltol=1e-2,
        abstol=1e-4,
        fun=gen_infinite_temperature,
        proj=1
    ):
    """
    This function generates a fake hubbard system with the given density, nearest neighbor, and next
    nearest neighbor correlation functions. Doublons and holes are treated the same.

    Starting from fun(sl, free_holes, doub_hole, proj), the function repeatedly picks a random
    occupied site and flips its spin. A flip is kept unless it increases the squared distance
    between the current and target lattice-summed correlators. The loop stops once every active
    target (nn, plus nnn when sz_nnn is given) meets both the absolute and the relative tolerance.
    NOTE(review): there is no iteration cap, so an unreachable target makes this loop forever.

    :param sl: side length of return matrix
    :param free_holes: density of free holes
    :param doub_hole: density of doublon-hole pairs (0.1 means 10% doublon-hole pairs or 20% holes!)
    :param sz_nn: target nearest neighbor correlator (per site)
    :param sz_nnn: target next nearest neighbor correlator (per site), or None to ignore it
    :param reltol: relative tolerance for target (relative to |target|; skipped for a zero target)
    :param abstol: absolute tolerance for target (per site)
    :param fun: called as fun(sl, free_holes, doub_hole, proj) to build the starting matrix
    :param proj: passed through unchanged as the last argument of fun
    :return: matrix of occupations. 1 and -1 are spin states, 0 is not occupied
    """
    mat = fun(sl, free_holes, doub_hole, proj)
    n_sites = sl ** 2
    abs_limit = abstol * n_sites
    use_nnn = sz_nnn is not None

    # Targets and running values are lattice totals (per-site value * number of sites).
    nn = sz_nn * n_sites
    cur_sz = calc_corrs(mat)
    if use_nnn:
        nnn = sz_nnn * n_sites
        cur_sz_nnn = calc_corrs(mat, nnn=True)
    else:
        # Placeholders. With zero target, value and change, the nnn term drops out of the cost.
        nnn = cur_sz_nnn = 0.

    while True:
        converged = _within_tolerance(cur_sz, nn, abs_limit, reltol)
        if use_nnn:
            converged = converged and _within_tolerance(cur_sz_nnn, nnn, abs_limit, reltol)
        if converged:
            break

        (i, j) = (randint(0, sl - 1), randint(0, sl - 1))
        if mat[i, j] == 0:
            continue  # holes carry no spin to flip

        # Change of the lattice totals caused by flipping mat[i, j], as (before) - (after).
        delta_nn = _local_corr_sum(mat, i, j)
        delta_nnn = _local_corr_sum(mat, i, j, nnn=True) if use_nnn else 0.

        mat[i, j] = -mat[i, j]

        delta_nn -= _local_corr_sum(mat, i, j)
        if use_nnn:
            delta_nnn -= _local_corr_sum(mat, i, j, nnn=True)

        old_cost = (cur_sz - nn) ** 2 + (cur_sz_nnn - nnn) ** 2
        new_cost = (cur_sz - delta_nn - nn) ** 2 + (cur_sz_nnn - delta_nnn - nnn) ** 2
        if old_cost < new_cost:
            mat[i, j] = -mat[i, j]  # the flip moved us away from the target: undo it
        else:
            cur_sz -= delta_nn
            cur_sz_nnn -= delta_nnn

    return mat


def _within_tolerance(current, target, abs_limit, reltol):
    """Return True if current is within abs_limit of target and within reltol relative to |target|."""
    error = abs(current - target)
    if not error < abs_limit:
        return False
    if target == 0:
        # A relative error is undefined for a zero target; the absolute test alone decides.
        return True
    # abs(target): dividing by a negative target would make the relative test always pass.
    return error / float(abs(target)) < reltol


def _local_corr_sum(mat, i, j, nnn=False):
    """Sum corr_at over (i, j) and its periodic neighbours, counting each distinct site once."""
    sl = mat.shape[0]
    if nnn:
        offsets = ((0, 0), (1, 1), (-1, -1), (-1, 1), (1, -1))
    else:
        offsets = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
    # On lattices with sl < 3, different offsets wrap onto the same site. That site's correlator
    # must still enter the total only once, or the running sums drift from calc_corrs.
    sites = sorted(set(((i + di) % sl, (j + dj) % sl) for di, dj in offsets))
    return sum(corr_at(mat, a, b, nnn=nnn) for a, b in sites)


def gen_fake_hubbard(sl, free_holes, doub_hole, sz_nn, sz_nnn, reltol=1e-2, abstol=1e-4):
    """
    Convenience wrapper around gen_fake that always starts from an infinite-temperature
    configuration (gen_infinite_temperature) and leaves proj at gen_fake's default.
    All arguments are forwarded unchanged; the matrix returned by gen_fake is returned.
    """
    options = dict(reltol=reltol, abstol=abstol, fun=gen_infinite_temperature)
    return gen_fake(sl, free_holes, doub_hole, sz_nn, sz_nnn, **options)


# From classical AFM; avoids overlapping singlets or dh pairs; noblow, blow1, blow2
def phenom_fromafm(fn_template, fn_out, num_ams, cths=1.0, singlets=0.0, doub_hole=0.04, free_holes=0.0, imgfid=1.0):
    """
    Generate synthetic shots starting from a classical antiferromagnet (checkerboard) and pickle them.

    For each realization k the checkerboard first receives free holes, doublon-hole pairs and
    singlets at random, without overlaps. Next, every non-singlet spin is kept unflipped with
    probability cthss[k] and flipped otherwise; cthss[k] is drawn uniformly from [cths, 1).
    Finally, each site is kept with probability imgfid.

    Three passes of num_ams fresh realizations are written. Pass 0 is a density shot, where any
    particle reads 1. Passes 1 and 2 keep the spin value. Empty sites read -1 in every pass, and
    sites outside the template mask read 0.

    :param fn_template: pickle whose [0][0] entry is a 2-D shot; its non-zero sites form the mask
    :param fn_out: output pickle path; receives a list of 3 passes of num_ams nested-list shots
    :param num_ams: number of realizations per pass
    :param cths: lower bound of the cos(theta/2)**2 spread
    :param singlets: number of singlets as a fraction of the sites (each singlet occupies 2 sites)
    :param doub_hole: density of doublon-hole pairs (each pair empties 2 sites)
    :param free_holes: density of free holes
    :param imgfid: imaging fidelity, i.e. probability that a site survives imaging
    :raises ValueError: if the requested defects cannot fit on the lattice
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted template files.
    with open(fn_template, 'rb') as f:
        template = pickle.load(f)

    mask = np.asarray(template[0][0]) != 0
    cthss = np.random.uniform(cths, 1, num_ams)  # Have a spread of cos(theta/2)**2 which are possible

    ams_full = []

    for shot_pass in range(3):
        shots = np.zeros((num_ams,) + mask.shape)

        for k in range(num_ams):
            # Start with a perfect checkerboard, then put in free holes, doublon-holes and singlets
            am = _afm_checkerboard(mask.shape)
            _place_afm_defects(am, free_holes, doub_hole, singlets)

            # Projective measurement. -1 is flip; 1 is not flip. The keep probability is the one
            # drawn for this realization k.
            p = (np.random.uniform(0, 1, mask.shape) < cthss[k]) * 2 - 1
            # Singlets (+/-2) are unaffected by projective measurement.
            p = np.where(abs(am) == 2, 1, p)
            am = am * p

            # Imaging fidelity: each site is kept with probability imgfid. A lost particle leaves
            # an empty site (0), so it can never show up as present in the density shot.
            am = np.where(np.random.uniform(0, 1, mask.shape) < imgfid, am, 0)

            # Turn singlets into ordinary +/-1 spins
            am = np.where(abs(am) == 2, am // 2, am)

            if shot_pass != 0:
                shots[k] = np.where(am == 0, -1, am) * mask  # make 0s into -1s, because that's how we detect
            else:
                shots[k] = np.where(am == 0, -1, abs(am)) * mask  # make 0s into -1s, but -1s into 1s for density shot

        ams_full.append(shots.tolist())

    with open(fn_out, 'wb') as f:
        pickle.dump(ams_full, f)


# Offsets (di, dj) selected by direction = randint(0, 3) when placing two-site defects.
_AFM_DIRECTIONS = ((1, 0), (-1, 0), (0, 1), (0, -1))


def _afm_checkerboard(shape):
    """Return an int array of the given shape holding -1 where row + col is even and +1 elsewhere."""
    rows, cols = np.indices(shape)
    return np.where((rows + cols) % 2 == 0, -1, 1)


def _afm_neighbour(i, j, direction, shape):
    """Return the periodic nearest neighbour of (i, j) selected by direction (0..3)."""
    di, dj = _AFM_DIRECTIONS[direction]
    return (i + di) % shape[0], (j + dj) % shape[1]


def _place_afm_defects(am, free_holes, doub_hole, singlets):
    """Place free holes (0), doublon-hole pairs (0, 0) and singlets (+2/-2) on plain-spin sites of am, in place."""
    rows, cols = am.shape
    n_sites = rows * cols
    free_left = int(free_holes * n_sites)
    dh_left = int(doub_hole * n_sites)
    singlets_left = int(singlets * n_sites)
    # An over-full request would otherwise make the placement loop spin forever.
    if free_left + 2 * (dh_left + singlets_left) > n_sites:
        raise ValueError('requested defects need %d sites but the lattice has only %d'
                         % (free_left + 2 * (dh_left + singlets_left), n_sites))

    def available(site):
        # only sites still carrying a plain +/-1 spin may host a new defect (0 = hole, +/-2 = singlet)
        return am[site] != 0 and abs(am[site]) != 2

    while free_left > 0 or dh_left > 0 or singlets_left > 0:
        (i, j) = (randint(0, rows - 1), randint(0, cols - 1))
        if not available((i, j)):
            continue

        # choose the defect type in proportion to how many of each are still to be placed
        total = float(free_left + dh_left + singlets_left)
        r = random()
        if r < free_left / total:
            # placing free hole
            am[i, j] = 0
            free_left -= 1
        elif r < (free_left + dh_left) / total:
            # placing hole + doublon
            partner = _afm_neighbour(i, j, randint(0, 3), am.shape)
            if available(partner):
                am[i, j] = 0
                am[partner] = 0
                dh_left -= 1
        else:
            # placing singlet: opposite +/-2 marks on the two sites
            direction = randint(0, 3)
            flip = randint(0, 1) * 4 - 2  # -2 or 2
            partner = _afm_neighbour(i, j, direction, am.shape)
            if available(partner):
                am[i, j] = flip
                am[partner] = -flip
                singlets_left -= 1


# Matched correlators from random spins; noblow, blow1, blow2
def phenom_matchcorr(fn_template, fn_out, num_ams, fh, dh, nn, nnn):
    """
    Generate random-spin shots whose correlators are matched to targets and pickle them.

    The window is the bounding box of the non-zero sites of the template's first shot
    (template[0][0]); sites inside the box that are zero in the template are masked to 0.
    Three passes of num_ams fresh gen_fake_hubbard realizations are written. Pass 0 is a density
    shot, where any particle reads 1. Passes 1 and 2 keep the spin value. Empty sites read -1.

    :param fn_template: pickle whose [0][0] entry is a 2-D shot defining the window and mask
    :param fn_out: output pickle path; receives a list of 3 passes of num_ams nested-list shots
    :param num_ams: number of realizations per pass
    :param fh: density of free holes
    :param dh: density of doublon-hole pairs
    :param nn: target nearest neighbor correlator
    :param nnn: target next nearest neighbor correlator, or None
    :raises ValueError: if the template shot has no non-zero sites
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted template files.
    with open(fn_template, 'rb') as f:
        template = pickle.load(f)

    # cut out the bounding box of the occupied sites (asarray so list templates also work)
    first_shot = np.asarray(template[0][0])
    rows, cols = np.nonzero(first_shot)
    if rows.size == 0:
        raise ValueError('template shot has no non-zero sites; cannot determine the window')
    window = first_shot[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    mask = window != 0
    height, width = mask.shape
    # gen_fake_hubbard only builds square lattices: make one that covers the window, then crop it
    sl = max(height, width)

    shots = np.zeros((num_ams, height, width))
    ams_full = []

    for shot_pass in range(3):
        for k in range(num_ams):
            # 1, -1 spin states; 0 not occupied
            am = gen_fake_hubbard(sl, fh, dh, nn, nnn, abstol=1e-3)[:height, :width]
            if shot_pass != 0:
                shots[k] = np.where(am == 0, -1, am) * mask  # make 0s into -1s, because that's how we detect
            else:
                shots[k] = np.where(am == 0, -1, abs(am)) * mask  # make 0s into -1s, but -1s into 1s for density shot

        # tolist() copies, so reusing `shots` for the next pass is safe
        ams_full.append(shots.tolist())

    with open(fn_out, 'wb') as f:
        pickle.dump(ams_full, f)


if __name__ == '__main__':
    # parameters go here
    templatename = 'cold_D32.pkl'

    sel = 'matchcorr'  # 'matchcorr' or 'fromafm'
    num_ams = 100      # number of atom matrices per (dens, upspins, downspins)
    cths = np.cos(1.0*np.pi/2)**2  # [0, 1], cos(theta/2)**2
    imgfid = 1.00      # imaging fidelity: probability that a site survives imaging

    # both; each setting is a scalar or a list (lists are zipped, one output file per entry)
    fh = [0.00, 0.10]  # free hole densities
    dh = [0.04, 0.03]  # doublon-hole pair densities
    # matched correlators
    nn = [-0.25, -0.17]
    nnn = [None, None]  # [0.12, 0.04]  # be warned, including nnn adds some time. consider changing abstol
    # from classical afm
    singlets = [0.30, 0.30]  # fraction of 'singlets'

    if sel == 'matchcorr':
        if not isinstance(nn, list):
            if nnn is not None:
                on = '{}_{:1.02f}-{:1.02f}_{:1.02f}-{:1.02f}.pkl'.format(sel, fh, dh, nn, nnn)
            else:
                on = '{}_{:1.02f}-{:1.02f}_{:1.02f}-None.pkl'.format(sel, fh, dh, nn)
            phenom_matchcorr(templatename, on, num_ams, fh, dh, nn, nnn)
        else:
            for f, d, n2, n3 in zip(fh, dh, nn, nnn):
                if n3 is not None:
                    on = '{}_{:1.02f}-{:1.02f}_{:1.02f}-{:1.02f}.pkl'.format(sel, f, d, n2, n3)
                else:
                    on = '{}_{:1.02f}-{:1.02f}_{:1.02f}-None.pkl'.format(sel, f, d, n2)
                phenom_matchcorr(templatename, on, num_ams, f, d, n2, n3)

    elif sel == 'fromafm':
        if not isinstance(singlets, list):
            on = '{}_{:1.02f}-{:1.02f}_{:1.02f}.pkl'.format(sel, fh, dh, singlets)
            phenom_fromafm(templatename, on, num_ams, cths, singlets, dh, fh, imgfid)
        else:
            for f, d, s in zip(fh, dh, singlets):
                on = '{}_{:1.02f}-{:1.02f}_{:1.02f}.pkl'.format(sel, f, d, s)
                phenom_fromafm(templatename, on, num_ams, cths, s, d, f, imgfid)

    else:
        # print(...) with a single argument behaves the same on Python 2 and 3
        print("error- variable sel should be set to 'matchcorr' or 'fromafm'")

